{
 "cells": [
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Import Libraries"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "import cv2 \n",
    "import pickle\n",
    "import matplotlib.pyplot as plt"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "cv2.imread(\"data/images/taj1.jpg\",0)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Resize Images function"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Resize images to a similar dimension\n",
    "# This helps improve accuracy and decreases unnecessarily high number of keypoints\n",
    "\n",
    "def imageResizeTrain(image):\n",
    "    maxD = 1024\n",
    "    height,width = image.shape\n",
    "    aspectRatio = width/height\n",
    "    if aspectRatio < 1:\n",
    "        newSize = (int(maxD*aspectRatio),maxD)\n",
    "    else:\n",
    "        newSize = (maxD,int(maxD/aspectRatio))\n",
    "    image = cv2.resize(image,newSize)\n",
    "    return image\n",
    "\n",
    "def imageResizeTest(image):\n",
    "    maxD = 1024\n",
    "    height,width,channel = image.shape\n",
    "    aspectRatio = width/height\n",
    "    if aspectRatio < 1:\n",
    "        newSize = (int(maxD*aspectRatio),maxD)\n",
    "    else:\n",
    "        newSize = (maxD,int(maxD/aspectRatio))\n",
    "    image = cv2.resize(image,newSize)\n",
    "    return image"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Generate Keypoint and Descriptors"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Prepare list of images"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Define a list of images the way you like\n",
    "\n",
    "imageList = [\"taj1.jpg\",\"taj2.jpg\",\"eiffel1.jpg\",\"eiffel2.jpg\",\"liberty1.jpg\",\"liberty2.jpg\",\"robert1.jpg\",\"tom1.jpg\",\"ironman1.jpg\",\"ironman2.jpg\",\"ironman3.jpg\",\"darkknight1.jpg\",\"darkknight2.jpg\",\"book1.jpg\",\"book2.jpg\"]"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# We use grayscale images for generating keypoints\n",
    "imagesBW = []\n",
    "for imageName in imageList:\n",
    "    imagePath = \"data/images/\" + str(imageName)\n",
    "    imagesBW.append(imageResizeTrain(cv2.imread(imagePath,0))) # flag 0 means grayscale"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Gray-scale images improve speed\n",
    "\n",
    "# imageBW = []\n",
    "# for image in images:\n",
    "#     imageBW.append(cv2.cvtColor(image, cv2.COLOR_BGR2GRAY))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Choose between opencv's SIFT and open source SIFT implementation<br>"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Using opencv's sift implementation here\n",
    "\n",
    "sift = cv2.SIFT_create()\n",
    "\n",
    "def computeSIFT(image):\n",
    "    return sift.detectAndCompute(image, None)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# # Using custom (rmislam/PythonSIFT) sift implementation\n",
    "# # Comparitively slower than opencv's implementation\n",
    "# # but source code is present in repo, hence easily modifiable\n",
    "\n",
    "# import pysift\n",
    "\n",
    "# def computeSIFT(image):\n",
    "#     return pysift.computeKeypointsAndDescriptors(image)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### The following is the main function to generate the keypoints and descriptors<br>\n",
    "When using SIFT, this takes a lot of time to compute.<br>\n",
    "Thus, it is suggested, you store the values once computed<br>\n",
    "(Code for storing is written below)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "keypoints = []\n",
    "descriptors = []\n",
    "for i,image in enumerate(imagesBW):\n",
    "    print(\"Starting for image: \" + imageList[i])\n",
    "    keypointTemp, descriptorTemp = computeSIFT(image)\n",
    "    keypoints.append(keypointTemp)\n",
    "    descriptors.append(descriptorTemp)\n",
    "    print(\"  Ending for image: \" + imageList[i])\n"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Store Keypoints and Descriptors for future use"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i,keypoint in enumerate(keypoints):\n",
    "    deserializedKeypoints = []\n",
    "    filepath = \"data/keypoints/\" + str(imageList[i].split('.')[0]) + \".txt\"\n",
    "    for point in keypoint:\n",
    "        temp = (point.pt, point.size, point.angle, point.response, point.octave, point.class_id)\n",
    "        deserializedKeypoints.append(temp)\n",
    "    with open(filepath, 'wb') as fp:\n",
    "        pickle.dump(deserializedKeypoints, fp)    \n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "for i,descriptor in enumerate(descriptors):\n",
    "    filepath = \"data/descriptors/\" + str(imageList[i].split('.')[0]) + \".txt\"\n",
    "    with open(filepath, 'wb') as fp:\n",
    "        pickle.dump(descriptor, fp)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Prepare for fetching results"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Fetch Keypoints and Descriptors from stored files"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fetchKeypointFromFile(i):\n",
    "    filepath = \"data/keypoints/\" + str(imageList[i].split('.')[0]) + \".txt\"\n",
    "    keypoint = []\n",
    "    file = open(filepath,'rb')\n",
    "    deserializedKeypoints = pickle.load(file)\n",
    "    file.close()\n",
    "    for point in deserializedKeypoints:\n",
    "        temp = cv2.KeyPoint(\n",
    "            x=point[0][0],\n",
    "            y=point[0][1],\n",
    "            size=point[1],\n",
    "            angle=point[2],\n",
    "            response=point[3],\n",
    "            octave=point[4],\n",
    "            class_id=point[5]\n",
    "        )\n",
    "        keypoint.append(temp)\n",
    "    return keypoint"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def fetchDescriptorFromFile(i):\n",
    "    filepath = \"data/descriptors/\" + str(imageList[i].split('.')[0]) + \".txt\"\n",
    "    file = open(filepath,'rb')\n",
    "    descriptor = pickle.load(file)\n",
    "    file.close()\n",
    "    return descriptor"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Calculate Results for any pair"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculateResultsFor(i,j):\n",
    "    keypoint1 = fetchKeypointFromFile(i)\n",
    "    descriptor1 = fetchDescriptorFromFile(i)\n",
    "    keypoint2 = fetchKeypointFromFile(j)\n",
    "    descriptor2 = fetchDescriptorFromFile(j)\n",
    "    matches = calculateMatches(descriptor1, descriptor2)\n",
    "    score = calculateScore(len(matches),len(keypoint1),len(keypoint2))\n",
    "    plot = getPlotFor(i,j,keypoint1,keypoint2,matches)\n",
    "    print(len(matches),len(keypoint1),len(keypoint2),len(descriptor1),len(descriptor2))\n",
    "    print(score)\n",
    "    plt.imshow(plot),plt.show()"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getPlotFor(i,j,keypoint1,keypoint2,matches):\n",
    "    image1 = imageResizeTest(cv2.imread(\"data/images/\" + imageList[i]))\n",
    "    image2 = imageResizeTest(cv2.imread(\"data/images/\" + imageList[j]))\n",
    "    return getPlot(image1,image2,keypoint1,keypoint2,matches)"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Basic Scoring metric\n",
    "A score greater than 10 means very good"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def calculateScore(matches,keypoint1,keypoint2):\n",
    "    return 100 * (matches/min(keypoint1,keypoint2))"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use this part of code for brute force matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# bf = cv2.BFMatcher(cv2.NORM_L1, crossCheck=True)\n",
    "# def calculateMatches(descriptor1,descriptor2):\n",
    "#     matches = bf.match(descriptor1,descriptor2)\n",
    "#     matches = sorted(matches, key = lambda x:x.distance)\n",
    "#     return matches"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# def getPlot(image1,image2,keypoint1,keypoint2,matches):\n",
    "#     matchPlot = cv2.drawMatches(image1, keypoint1, image2, keypoint2, matches[:50], image2, flags=2)\n",
    "#     return matchPlot"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Use this part of code for knn matching"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "bf = cv2.BFMatcher()\n",
    "def calculateMatches(des1,des2):\n",
    "    matches = bf.knnMatch(des1,des2,k=2)\n",
    "    topResults1 = []\n",
    "    for m,n in matches:\n",
    "        if m.distance < 0.7*n.distance:\n",
    "            topResults1.append([m])\n",
    "            \n",
    "    matches = bf.knnMatch(des2,des1,k=2)\n",
    "    topResults2 = []\n",
    "    for m,n in matches:\n",
    "        if m.distance < 0.7*n.distance:\n",
    "            topResults2.append([m])\n",
    "    \n",
    "    topResults = []\n",
    "    for match1 in topResults1:\n",
    "        match1QueryIndex = match1[0].queryIdx\n",
    "        match1TrainIndex = match1[0].trainIdx\n",
    "\n",
    "        for match2 in topResults2:\n",
    "            match2QueryIndex = match2[0].queryIdx\n",
    "            match2TrainIndex = match2[0].trainIdx\n",
    "\n",
    "            if (match1QueryIndex == match2TrainIndex) and (match1TrainIndex == match2QueryIndex):\n",
    "                topResults.append(match1)\n",
    "    return topResults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "# Use this if you want faster results and can give up some accuracy\n",
    "# bf = cv2.BFMatcher()\n",
    "# def calculateMatches(des1,des2):\n",
    "#     matches = bf.knnMatch(des1,des2,k=2)\n",
    "#     topResults = []\n",
    "#     for m,n in matches:\n",
    "#         if m.distance < 0.7*n.distance:\n",
    "#             topResults.append([m])\n",
    "#     return topResults"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "def getPlot(image1,image2,keypoint1,keypoint2,matches):\n",
    "    image1 = cv2.cvtColor(image1, cv2.COLOR_BGR2RGB)\n",
    "    image2 = cv2.cvtColor(image2, cv2.COLOR_BGR2RGB)\n",
    "    matchPlot = cv2.drawMatchesKnn(\n",
    "        image1,\n",
    "        keypoint1,\n",
    "        image2,\n",
    "        keypoint2,\n",
    "        matches,\n",
    "        None,\n",
    "        [255,255,255],\n",
    "        flags=2\n",
    "    )\n",
    "    return matchPlot"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "## Fetch Results"
   ]
  },
  {
   "attachments": {},
   "cell_type": "markdown",
   "metadata": {},
   "source": [
    "### Sample Results"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": [
    "calculateResultsFor(13,14)\n"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": null,
   "metadata": {},
   "outputs": [],
   "source": []
  }
 ],
 "metadata": {
  "kernelspec": {
   "display_name": "Python 3 (ipykernel)",
   "language": "python",
   "name": "python3"
  },
  "language_info": {
   "codemirror_mode": {
    "name": "ipython",
    "version": 3
   },
   "file_extension": ".py",
   "mimetype": "text/x-python",
   "name": "python",
   "nbconvert_exporter": "python",
   "pygments_lexer": "ipython3",
   "version": "3.9.13"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 4
}
